% Reset the figure colour scheme and close any figures left over from a
% previous run of this script.
colordef('white')
close('all')


% waitfor(helpdlg("Select all wild type (WT) data files from one date (.mat)."))
% [fileWT, pathWT, indzWT] = uigetfile('*.mat', 'MultiSelect', 'on');
% cd(pathWT)
% cd('..')
% 
% waitfor(helpdlg("Select all heterozygous data files from one date (.mat)."))
% [fileHet, pathHet, indzHet] = uigetfile('*.mat', 'MultiSelect', 'on');
% 
% waitfor(helpdlg("Select all homozygous data files from one date (.mat)."))
% [fileHom, pathHom, indzHom] = uigetfile('*.mat', 'MultiSelect', 'on');
% 
% 
% % Use the filenames directly, no need to define Mouse.wt, Mouse.het, and Mouse.homo explicitly
% Mouse.wt = fileWT;  % List of WT data files
% Mouse.het = fileHet;  % List of HET data files
% Mouse.homo = fileHom;  % List of HOMO data files

% ---- Analysis configuration --------------------------------------------
% Genotype groups; each name is also the folder read by the processing loop.
groups = {'wt', 'het', 'homo'};

% Sensors and their three axes: row 1 = Accelerometer, row 2 = Gyroscope.
Sensors = {'Accelerometer', 'Gyroscope'};
AXS = {'Foreaft', 'Lat',   'Vert'
       'Roll',    'Pitch', 'Yaw'};

% Spectral settings: maximum Welch window length (samples) and the
% frequencies (Hz) at which the spectrum is evaluated.
NFFT = 4096; % Window length
Freq_vector = 0.5:0.5:30;

% Plot colour per group (RGB scaled to [0, 1]).
colors.wt   = [0, 0, 0]/255;
colors.het  = [41, 170, 225]/255;
colors.homo = [255, 0, 0]/255;

% ---- Accumulators and switches -----------------------------------------
Pxx = [];        % filled per group / sensor-axis by the processing loop
bands = [0 30];  % frequency bands [low high] in Hz, one row per band
n = 0;           % vertical offset for stacked gyroscope traces
xlog = false;    % true -> logarithmic frequency axis
% Per-file spectral analysis.
% For each group folder (wt/het/homo): load every .mat file, keep only the
% samples inside the [start stop] rows of Data.M, and compute a Welch
% spectrum per sensor axis at the frequencies in Freq_vector (fs = 1000 Hz).
% Results: Pxx.(group).(<Sensor>_<axis>) holds one column per file (rows =
% frequency bins); Pxx.(group).band<k>.(<Sensor>_<axis>) holds the per-file
% mean over band k of `bands`.
% Figures: 1..3 = per-group spectra, 100/101 = stacked gyroscope pitch/yaw.
for i = 1:length(groups)
    group = groups{i};

    % Only .mat files in the group's folder (directories '.'/'..' excluded).
    listing = dir(fullfile(group, '*.mat'));
    Mouse.(group) = {listing.name};
    if isempty(Mouse.(group))
        warning('No .mat files found in folder "%s". Skipping this group.', group);
        continue
    end

    for j = 1:numel(Mouse.(group))
        % Load into a struct so variables stored in the file cannot overwrite
        % loop counters/accumulators here, and a file without Data fails
        % loudly instead of silently reusing the previous file's Data.
        S = load(fullfile(group, Mouse.(group){j}), 'Data');
        Data = S.Data;

        % Mask of samples lying inside any [start, stop] row of Data.M.
        keep = false(Data.N, 1);
        for M_index = 1:size(Data.M, 1)
            keep(Data.M(M_index, 1):Data.M(M_index, 2)) = true;
        end

        for sen = 1:length(Sensors)
            for axs = 1:3
                field = [Sensors{sen} '_' num2str(axs)];
                sig = Data.(AXS{sen, axs});
                sig = sig(keep);

                % Drop NaN/Inf BEFORE removing the mean: a single Inf would
                % otherwise make the mean infinite and discard every sample.
                sig = sig(isfinite(sig));
                if isempty(sig)
                    warning('Input signal is empty after filtering. Skipping this segment.');
                    % Store NaNs so this file is treated as missing, not as a
                    % zero-power recording, when later columns are filled in.
                    Pxx.(group).(field)(:, j) = NaN(nnz(Freq_vector < 50), 1);
                    continue
                end
                sig = sig - mean(sig);

                % Welch window: NFFT samples, or the whole signal if shorter.
                NFFT_adjusted = min(NFFT, numel(sig));
                % 50% overlap; floor keeps it an integer for odd window
                % lengths and strictly smaller than the window.
                noverlap = floor(NFFT_adjusted / 2);

                % Smoothed gyroscope pitch (fig 100) and yaw (fig 101) traces,
                % stacked vertically: the offset grows after each yaw trace.
                if sen == 2 && axs == 2
                    figure(100)
                    hold on
                    plot(movmean(sig, 2000) + n, 'color', colors.(group))
                elseif sen == 2 && axs == 3
                    figure(101)
                    hold on
                    plot(movmean(sig, 2000) + n, 'color', colors.(group))
                    n = n + 15;
                end

                % Per-group figure, one subplot per sensor axis.
                figure(i)
                subplot(2, 3, axs + (sen - 1) * 3);
                hold on
                if sen == 1
                    ylabel('Power (g^2.s)')
                else
                    ylabel('Power (deg^2/s)')
                end

                % NOTE(review): 'power' yields a power spectrum, while the
                % y-labels use PSD-style units (per Hz) — confirm intent.
                % PSD alternative:
                % [TMP_Pxx, TMP_F] = pwelch(sig, bartlett(NFFT_adjusted), noverlap, NFFT_adjusted, 1000);
                [TMP_Pxx, TMP_F] = pwelch(sig, bartlett(NFFT_adjusted), noverlap, Freq_vector, 1000, 'power');
                F = TMP_F(TMP_F < 50);
                Pxx.(group).(field)(:, j) = TMP_Pxx(TMP_F < 50);

                plot(F, Pxx.(group).(field)(:, j), 'color', colors.(group))

                if xlog
                    set(gca, 'Xscale', 'log', 'Yscale', 'log')
                else
                    set(gca, 'Yscale', 'log')
                end
                title(AXS{sen, axs}, 'FontSize', 20)
                xlabel('Frequency (Hz)')
            end
        end
    end

    % Band averages per file. Bins are derived from Freq_vector (the same
    % frequencies requested from pwelch), so this does not depend on F having
    % been set by this group; both band edges are inclusive.
    Fbins = Freq_vector(Freq_vector < 50);
    for sen = 1:length(Sensors)
        for axs = 1:3
            field = [Sensors{sen} '_' num2str(axs)];
            for bnd = 1:size(bands, 1)
                ind = Fbins >= bands(bnd, 1) & Fbins <= bands(bnd, 2);
                Pxx.(group).(['band' num2str(bnd)]).(field) = ...
                    mean(Pxx.(group).(field)(ind, :), 1);
            end
        end
    end
end

% Group summary on figure 200: for each sensor axis, plot each group's
% across-file mean of log10(power) with a shaded +/- 1 SEM band.
figure(200)
for sen = 1:length(Sensors)
    for axs = 1:3
        subplot(2, 3, axs + (sen - 1) * 3);
        if sen == 1
            ylabel('Power (g^2.s)')
        else
            ylabel('Power (deg^2/s)')
        end
        hold on
        if xlog
            datax = log10(F);
        else
            datax = F;
        end
        % Force a row so the fill polygon below can be built by horizontal
        % concatenation regardless of the orientation F has.
        datax = reshape(datax, 1, []);

        field = [Sensors{sen} '_' num2str(axs)];
        for i = 1:length(groups)
            if ~isfield(Pxx, groups{i})
                continue % group had no processed files
            end
            % files x frequency matrix of log10 power.
            datay = log10(Pxx.(groups{i}).(field).');
            % Skipped (NaN) or zero-power (-Inf) files are treated as missing.
            datay(~isfinite(datay)) = NaN;
            nValid = sum(~isnan(datay), 1);
            mu = mean(datay, 1, 'omitnan');
            % Standard error across files; dim 1 is explicit so a single-file
            % group does not collapse std to a scalar.
            sd = std(datay, 0, 1, 'omitnan') ./ sqrt(nValid);
            fill([datax, fliplr(datax)], [mu - sd, fliplr(mu + sd)], colors.(groups{i}), 'FaceAlpha', .2, 'EdgeAlpha', 0)
            plot(datax, mu, 'color', colors.(groups{i}))
        end
        title(AXS{sen, axs}, 'FontSize', 20)
        xlabel('Frequency (Hz)')
    end
end



% Shade each frequency band listed in `bands` (rows of [low high] Hz) on
% every subplot of figure 200, with vertical lines at both edges.
figure(200)
for sen = 1:length(Sensors)
    for axs = 1:3
        subplot(2, 3, axs + (sen - 1) * 3);
        hold on
        for bnd = 1:size(bands, 1)
            % Snap each band edge to the nearest analysed frequency bin
            % (min returns the first index on ties).
            [~, f1] = min(abs(F - bands(bnd, 1)));
            [~, f2] = min(abs(F - bands(bnd, 2)));
            x1 = F(f1);
            x2 = F(f2);
            % Figure 200 is drawn against log10(F) on a linear axis when
            % xlog is set, so the band edges must be transformed the same way.
            if xlog
                x1 = log10(x1);
                x2 = log10(x2);
            end
            AX = axis;
            fill([x1 x1 x2 x2 x1], [AX(3) AX(4) AX(4) AX(3) AX(3)], 'k', FaceAlpha=.2, EdgeAlpha=0)
            plot([x1 x1], AX(3:4), 'k') % Start of band
            plot([x2 x2], AX(3:4), 'k') % End of band
            ylim(AX(3:4)) % keep the y-range from before shading
        end
    end
end

%%